Semantic segmentation for Usage and Coverage mapping with Pytorch and TorchGeo¶

image.png

TorchGeo is a PyTorch domain library that provides datasets, samplers, transformations, and pre-trained models specific to geospatial data.

In traditional computer vision datasets like ImageNet, the image files themselves tend to be quite simple and easy to work with. Most images have 3 spectral bands (RGB), are stored in common file formats like PNG or JPEG, and can be easily loaded with popular software libraries like PIL or OpenCV. Each image in these datasets is usually small enough to pass directly to a neural network. Furthermore, most of these datasets contain a finite number of well-selected images that are assumed to be independent and identically distributed, making train-val-test splits straightforward. As a result of this relative homogeneity, the same pre-trained models (e.g., CNNs pre-trained on ImageNet) have been shown to be effective on a wide range of vision tasks using transfer learning methods. Existing libraries like torchvision handle these simple cases well and have been used to make major advances in vision tasks over the past decade.

Remote sensing imagery is not as uniform. Instead of simple RGB images, satellites tend to capture multispectral images (Landsat 8 has 11 spectral bands) or even hyperspectral images (Hyperion has 242 spectral bands). These images capture information across a wider range of wavelengths (400 nm–15 µm), well outside the visible spectrum. Different satellites also have very different spatial resolutions – GOES has a resolution of 4 km/px, Maxar imagery is 30 cm/px, and drone imagery can be as low as 7 mm/px. These datasets almost always have a temporal component, with satellite revisions occurring daily, weekly, or biweekly. The images often overlap with other images in the dataset and need to be stitched together based on geographic metadata. These images tend to be very large (e.g. 10K x 10K pixels), so it is not possible to feed an entire image through a neural network. This data is distributed in hundreds of different raster and vector file formats, such as GeoTIFF and ESRI Shapefile, requiring specialized libraries such as GDAL to be loaded.

TorchGeo is designed to have the same API as other PyTorch domain libraries such as torchvision, torchtext, and torchaudio. If you already use torchvision in your workflow for computer vision datasets, you can switch to TorchGeo with just a few lines of code changes. All TorchGeo datasets and samplers are compatible with the PyTorch DataLoader class, which means you can take advantage of wrapper libraries like PyTorch Lightning for distributed training. In the following sections, we will explore possible use cases for TorchGeo to show how simple it is to use.

Let's perform the installation:

In [ ]:
!pip install -q git+https://github.com/microsoft/torchgeo.git
!pip install -q GPUtil
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 1.2 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 705.7/705.7 kB 13.4 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 661.6/661.6 kB 49.3 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.9/1.9 MB 72.5 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.3/21.3 MB 76.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.0/1.0 MB 69.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 106.7/106.7 kB 12.9 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 85.0 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 764.8/764.8 kB 62.7 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 154.5/154.5 kB 19.5 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 137.6/137.6 kB 17.4 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 727.7/727.7 kB 60.0 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.4/66.4 kB 8.7 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.3/71.3 kB 10.3 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.2/66.2 kB 8.5 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 660.0/660.0 kB 53.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 45.7/45.7 kB 5.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 69.9/69.9 kB 9.0 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 59.5/59.5 kB 7.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 129.9/129.9 kB 15.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.8/58.8 kB 6.3 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 80.5 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 294.8/294.8 kB 31.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 77.9 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.0/67.0 kB 8.4 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 79.5/79.5 kB 10.2 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 117.0/117.0 kB 14.6 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.4/58.4 kB 7.7 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.3/58.3 kB 6.7 MB/s eta 0:00:00
  Building wheel for torchgeo (pyproject.toml) ... done
  Building wheel for efficientnet-pytorch (setup.py) ... done
  Building wheel for pretrainedmodels (setup.py) ... done
  Building wheel for antlr4-python3-runtime (setup.py) ... done
  Preparing metadata (setup.py) ... done
  Building wheel for GPUtil (setup.py) ... done

Let's import the functions and connect with Drive:

In [ ]:
import os
import tempfile

from torch.utils.data import DataLoader

from torchgeo.datasets import NAIP, ChesapeakeDE, stack_samples
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler
In [ ]:
import lightning.pytorch as pl # Instead of import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples, RasterDataset
from torchgeo.datasets.splits import random_bbox_assignment
from torchgeo.samplers import RandomGeoSampler, RandomBatchGeoSampler, GridGeoSampler
import os
import matplotlib.pyplot as plt
import numpy as np
from torchgeo.trainers import SemanticSegmentationTask
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
import ssl
import multiprocessing as mp
from torchgeo.datamodules import GeoDataModule
from typing import Type
import albumentations as A
import timeit
import torch
import numpy as np
from rasterio.plot import show
from rasterio.merge import merge
import rasterio
from rasterio.transform import from_bounds, from_origin
from rasterio.crs import CRS
from rasterio.io import MemoryFile
In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

Chesapeake Land Cover¶

This dataset contains high-resolution aerial imagery from the USDA NAIP program, high-resolution land cover labels from the Chesapeake Conservancy, low-resolution land cover labels from the USGS NLCD 2011 dataset, low-resolution multispectral imagery from Landsat 8, and high-resolution building footprint masks from Microsoft Bing, formatted to accelerate machine learning research in land cover mapping. The Chesapeake Conservancy spent over 10 months and $1.3 million creating a robust, six-class land cover dataset covering the Chesapeake Bay watershed. While the goal of the Chesapeake Conservancy mapping effort was to create land cover data for use in conservation efforts, the same data can be used to train machine learning models that can be applied to even larger areas.

Organizing this dataset will allow users to easily test questions related to this geographic generalization problem, namely how to train machine learning models that can be applied to even broader areas. For example, this dataset could be used to directly estimate how well a model trained on Maryland data can generalize to the rest of the Chesapeake Bay.

image.png

The dataset consists of four NAIP orthomosaics and four files representing the labels:

image.png

We can define the dataset classes according to the color map and documentation codes:

In [ ]:
"""
Complete 13-class dataset.

    This version of the dataset is composed of 13 classes:

    0. No Data: Background values
    1. Water: All areas of open water including ponds, rivers, and lakes
    2. Wetlands: Low vegetation areas located along marine or estuarine regions
    3. Tree Canopy: Deciduous and evergreen woody vegetation over 3-5 meters in height
    4. Shrubland: Heterogeneous woody vegetation including shrubs and young trees
    5. Low Vegetation: Plant material less than 2 meters in height including lawns
    6. Barren: Areas devoid of vegetation consisting of natural earthen material
    7. Structures: Human-constructed objects made of impervious materials
    8. Impervious Surfaces: Human-constructed surfaces less than 2 meters in height
    9. Impervious Roads: Impervious surfaces that are used for transportation
    10. Tree Canopy over Structures: Tree cover overlapping impervious structures
    11. Tree Canopy over Impervious Surfaces: Tree cover overlapping impervious surfaces
    12. Tree Canopy over Impervious Roads: Tree cover overlapping impervious roads
    13. Aberdeen Proving Ground: U.S. Army facility with no labels
"""

names=[
    'No Data',
    'Water',
    'Wetlands',
    'Tree Canopy',
    'Shrubland',
    'Low Vegetation',
    'Barren',
    'Structures',
    'Impervious Surfaces',
    'Impervious Roads',
    'Tree Canopy over Structures',
    'Tree Canopy over Impervious Surfaces',
    'Tree Canopy over Impervious Roads',
    'Aberdeen Proving Ground',
]

# subclasses use the 13 class cmap by default
cmap = [
    (0, 0, 0, 0),
    (0, 197, 255, 255),
    (0, 168, 132, 255),
    (38, 115, 0, 255),
    (76, 230, 0, 255),
    (163, 255, 115, 255),
    (255, 170, 0, 255),
    (255, 0, 0, 255),
    (156, 156, 156, 255),
    (0, 0, 0, 255),
    (115, 115, 0, 255),
    (230, 230, 0, 255),
    (255, 255, 115, 255),
    (197, 0, 255, 255),
]

Let's also define the architecture and its parameters:

In [ ]:
EPOCHS = 15
LR = 1e-4

IN_CHANNELS = 4 # NAIP dataset has 4 bands
NUM_CLASSES = len(names) # Chesapeake dataset has 13 classes
IMG_SIZE = 256
BATCH_SIZE = 8
SAMPLE_SIZE = 500

PATIENCE = 5
SEGMENTATION_MODEL = 'deeplabv3+' # only supports 'unet', 'deeplabv3+' and 'fcn'
#BACKBONE = 'se_resnet50'
BACKBONE = 'resnet50' # supports TIMM encoders (https://smp.readthedocs.io/en/latest/encoders_timm.html)
WEIGHTS = 'imagenet'
LOSS = 'focal' # supports ‘ce’, ‘jaccard’ or ‘focal’ loss


DEVICE, NUM_DEVICES = ("cuda", torch.cuda.device_count()) if torch.cuda.is_available() else ("cpu", mp.cpu_count())
WORKERS = mp.cpu_count()
print(f'Running on {NUM_DEVICES} {DEVICE}(s)')
Running on 1 cuda(s)

Then, we define the dataset path and instantiate the semantic segmentation task:

In [ ]:
ssl._create_default_https_context = ssl._create_unverified_context

OUTPUT_DIR = '/content/'
INPUT_DIR = '/content/drive/MyDrive/Datasets/Naip_chesapeak'

TEST_DIR = os.path.join(OUTPUT_DIR, "test")
if not os.path.exists(TEST_DIR):
    os.makedirs(TEST_DIR)

logger = CSVLogger(
    TEST_DIR,
    name='torchgeo_logs'
)

checkpoint_callback = ModelCheckpoint(
    every_n_epochs=1,
    dirpath=TEST_DIR,
    filename='torchgeo_trained'
)

task = SemanticSegmentationTask(
    model = SEGMENTATION_MODEL,
    backbone = BACKBONE,
    weights = True, # to use imagenet. Before we should define weights='imagenet'
    in_channels = IN_CHANNELS,
    num_classes = NUM_CLASSES,
    loss = LOSS,
    ignore_index = None,
    learning_rate = LR,
    learning_rate_schedule_patience = PATIENCE,
)
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 83.2MB/s]

We also define a function to run the training:

In [ ]:
trainer = pl.Trainer(
        accelerator=DEVICE,
        devices=NUM_DEVICES,
        max_epochs=EPOCHS,
        callbacks=[checkpoint_callback, ],
        logger=logger,
    )
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs

Let's start loading the data. First we create the classes for the Images and Labels:

In [ ]:
class NAIPImages(RasterDataset):
    filename_glob = "m_*.tif"
    is_image = True
    separate_files = False

class ChesapeakeLabels(RasterDataset):
    filename_glob = "m_*.tif"
    is_image = False
    separate_files = False

Then we create a data augmentation and load the image and label data:

In [ ]:
data_augmentation_transform = A.Compose([
    A.Flip(),
    A.ShiftScaleRotate(),
    A.OneOf([
        A.RandomBrightness(),
        A.RandomGamma(),
    ]),
    A.CoarseDropout(max_height=32, max_width=32, max_holes=5)
])

naip_root = os.path.join(INPUT_DIR, 'naip_images')
naip_images = NAIPImages(
    root=naip_root,
    #transforms=data_augmentation_transform,
)
print(naip_images)

chesapeake_root = os.path.join(INPUT_DIR, "chesapeake_labels")
chesapeake_labels = ChesapeakeLabels(
    root=chesapeake_root,
    #transforms=data_augmentation_transform,
)
print(chesapeake_labels)
/usr/local/lib/python3.10/dist-packages/albumentations/augmentations/transforms.py:1258: FutureWarning: This class has been deprecated. Please use RandomBrightnessContrast
  warnings.warn(
NAIPImages Dataset
    type: GeoDataset
    bbox: BoundingBox(minx=440002.8, maxx=451549.7999999998, miny=4288884.0, maxy=4303429.2, mint=0.0, maxt=9.223372036854776e+18)
    size: 4
ChesapeakeLabels Dataset
    type: GeoDataset
    bbox: BoundingBox(minx=440002.9842051893, maxx=451549.3842051893, miny=4288884.327921044, maxy=4303428.927921044, mint=0.0, maxt=9.223372036854776e+18)
    size: 4

In this way, we join the image and label datasets and apply a sampler that will generate patches according to the size and number of samples that we defined previously.

In [ ]:
dataset = naip_images & chesapeake_labels # this means I'm creating an IntersectionDataset
sampler = RandomGeoSampler(dataset, size=IMG_SIZE, length=SAMPLE_SIZE)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)
Converting ChesapeakeLabels res from 0.6000000000000023 to 0.6

Let's now create the complete dataset class that will create the division of the data into training, testing and validation:

In [ ]:
class CustomGeoDataModule(GeoDataModule):
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.dataset = self.dataset_class(**self.kwargs)

        generator = torch.Generator().manual_seed(0)
        (
            self.train_dataset,
            self.val_dataset,
            self.test_dataset,
        ) = random_bbox_assignment(dataset, [0.6, 0.2, 0.2], generator)

        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )

datamodule = CustomGeoDataModule(
    dataset_class = type(dataset), # GeoDataModule kwargs
    batch_size = BATCH_SIZE, # GeoDataModule kwargs
    patch_size = IMG_SIZE, # GeoDataModule kwargs
    length = SAMPLE_SIZE, # GeoDataModule kwargs
    num_workers = WORKERS, # GeoDataModule kwargs
    dataset1 = naip_images, # IntersectionDataset kwargs
    dataset2 = chesapeake_labels, # IntersectionDataset kwargs
    collate_fn = stack_samples, # IntersectionDataset kwargs
)

We can see some examples of generated images and labels:

In [ ]:
def colour_code_segmentation(image, colors):
    """
    Given a 1-channel array of class keys, colour code the segmentation results.
    # Arguments
        image: single channel array where each value represents the class key.
        colormap: the list os rgb colors for each class

    # Returns
        Colour coded image for segmentation visualization
    """
    colour_codes = np.array(colors)
    return colour_codes[image]
In [ ]:
import torchvision.transforms as transforms
reverse_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(size=IMG_SIZE),
        ])

# helper function for data visualization
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()

n = 0
for sample in dataloader:
    if n == 10:
        break
    image, gt_mask = sample['image'], sample['mask']

    gt_mask = colour_code_segmentation(gt_mask, cmap)

    visualize(
        image=reverse_transform(image.squeeze()[:3]),
        ground_truth = gt_mask.squeeze(),
    )
    n += 1
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

With everything ready, let's run the model training and validation:

In [ ]:
start = timeit.default_timer() # Measuring the time

checkpoint_file = os.path.join(TEST_DIR, 'torchgeo_trained.ckpt')

if os.path.isfile(checkpoint_file):
    print('Resuming training from previous checkpoint...')
    trainer.fit(
        model=task,
        datamodule=datamodule,
        ckpt_path=checkpoint_file
    )
else:
    print('Starting training from scratch...')
    trainer.fit(
        model=task,
        datamodule = datamodule,
    )

print("The time taken to train was: ", timeit.default_timer() - start)
Starting training from scratch...
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:617: UserWarning: Checkpoint directory /content/test exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name          | Type             | Params
---------------------------------------------------
0 | model         | DeepLabV3Plus    | 26.7 M
1 | loss          | FocalLoss        | 0     
2 | train_metrics | MetricCollection | 0     
3 | val_metrics   | MetricCollection | 0     
4 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
26.7 M    Trainable params
0         Non-trainable params
26.7 M    Total params
106.736   Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
  | Name          | Type             | Params
---------------------------------------------------
0 | model         | DeepLabV3Plus    | 26.7 M
1 | loss          | FocalLoss        | 0     
2 | train_metrics | MetricCollection | 0     
3 | val_metrics   | MetricCollection | 0     
4 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
26.7 M    Trainable params
0         Non-trainable params
26.7 M    Total params
106.736   Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s]
Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
INFO: `Trainer.fit` stopped: `max_epochs=15` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=15` reached.
The time taken to train was:  338.4465540860001

Once trained, let's apply it to the test data:

In [ ]:
trainer.test(
    model=task,
    datamodule=datamodule
)
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Testing: 0it [00:00, ?it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃         Test metric         ┃        DataLoader 0         ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│   test_MulticlassAccuracy   │     0.8992854952812195      │
│ test_MulticlassJaccardIndex │     0.8170016407966614      │
│          test_loss          │    8.787868864601478e-06    │
└─────────────────────────────┴─────────────────────────────┘
Out[ ]:
[{'test_loss': 8.787868864601478e-06,
  'test_MulticlassAccuracy': 0.8992854952812195,
  'test_MulticlassJaccardIndex': 0.8170016407966614}]

Once trained and validated, we will use one of the NAIP orthomosaics to make the full prediction. First, we load the orthomosaic, divide it into patches, and perform the prediction for each patch. Then, we apply a Merge to the predicted patches and generate the georeferenced result:

In [ ]:
def create_in_memory_geochip(predicted_chip, geotransform, crs):
    """
    Apply georeferencing to the predicted chip.

    Parameters:
        predicted_chip (numpy array): The predicted segmentation chip (e.g., binary mask).
        geotransform (tuple): A tuple containing the geotransformation information of the chip (x-coordinate of the top left corner, x and y pixel size, rotation, y-coordinate of the top left corner, and rotation).
        crs (str): Spatial Reference System (e.g., EPSG code) of the chip.

    Return:
        A rasterio dataset that is georreferenced.
    """
    memfile = MemoryFile()
    dataset = memfile.open(
        driver='GTiff',
        height=predicted_chip.shape[0],
        width=predicted_chip.shape[1],
        count=predicted_chip.shape[2],  # Number of bands
        dtype=np.uint8,
        crs=crs,
        transform=geotransform,
        photometric='RGBA',
    )

    rolled_array = np.rollaxis(predicted_chip, axis=2) # putting the bands first
    dataset.write(rolled_array)
    return dataset
In [ ]:
def georreferenced_chip_generator(dataloader, model, crs, pixel_size, colors):
    """
    Apply georeferencing to the predicted chip.

    Parameters:
        dataloader (torch.utils.data.Dataloader): Dataloader with the data to be predicted.
        model (an https://github.com/qubvel/segmentation_models.pytorch model): model used for inference.
        crs (str): Spatial Reference System (e.g., EPSG code) of the chip.
        pixel_size (float): Pixel dimensoion in map units.

    Yields:
        A georeferenced numpy array of the predicted output.
    """
    georref_chips_list = []
    for i, sample in enumerate(dataloader):
        '''
        if i == 10:
            break
        '''

        image, gt_mask, bbox = sample['image'], sample['mask'], sample['bbox'][0]

        image = image/255. # as I'm not using a GeoDatamodule, I need to divide de images by 255 manually

        prediction = model.predict(image)
        prediction = torch.softmax(prediction, dim=1)
        prediction = torch.argmax(prediction, dim = 1)

        # reapplying the original colors in the reversed one hot images
        #gt_mask = colour_code_segmentation(gt_mask, colors)
        prediction = colour_code_segmentation(prediction, colors)

        geotransform = from_origin(bbox.minx, bbox.maxy, pixel_size, pixel_size)
        #yield create_in_memory_geochip(prediction.squeeze(), geotransform, crs)
        georref_chips_list.append(create_in_memory_geochip(prediction.squeeze(), geotransform, crs))
    return georref_chips_list
In [ ]:
def merge_georeferenced_chips(chips_generator, output_path):
    """
    Merge a list of georeferenced chips into a single GeoTIFF file.

    Parameters:
        chips_generator (generator): A generator of Rasterio datasets representing the georeferenced chips.
        output_path (str): The path where the merged GeoTIFF file will be saved.

    Returns:
        None
    """
    #chips_list = list(chips_generator)
    chips_list = chips_generator
    # Merge the chips using Rasterio's merge function
    merged, merged_transform = merge(chips_list)

    # Calculate the number of rows and columns for the merged output
    rows, cols = merged.shape[1], merged.shape[2]

    # Update the metadata of the merged dataset
    merged_metadata = chips_list[0].meta
    merged_metadata.update({
        'height': rows,
        'width': cols,
        'transform': merged_transform
    })

    # Write the merged array to a new GeoTIFF file
    with rasterio.open(output_path, 'w', **merged_metadata) as dst:
        dst.write(merged)

    for chip in chips_list:
        chip.close()
In [ ]:
test_dataset = datamodule.test_dataset
test_sampler = GridGeoSampler(test_dataset, 2048, 2048)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, collate_fn=stack_samples)

pixel_size = test_dataset.res
crs = test_dataset.crs.to_epsg()

start = timeit.default_timer() # Measuring the time
chips_generator = georreferenced_chip_generator(test_dataloader, task.model, crs, pixel_size, cmap)
print("The time taken to predict was: ", timeit.default_timer() - start)

start = timeit.default_timer() # Measuring the time
file_name = os.path.join(OUTPUT_DIR, 'merged_prediction.tif')
merge_georeferenced_chips(chips_generator, file_name)
print("The time taken to generate a georrefenced image and save it was: ", timeit.default_timer() - start)
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/c89d7f5b-0a75-4070-a851-677aedbeffaf/c89d7f5b-0a75-4070-a851-677aedbeffaf.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/ed7ad111-75d9-440e-9e08-3784b05a3ff1/ed7ad111-75d9-440e-9e08-3784b05a3ff1.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/c4ecb927-2cf8-4047-89bf-e8c71f7b7e28/c4ecb927-2cf8-4047-89bf-e8c71f7b7e28.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/ea0f8fb0-8e69-4338-9cd1-8a72a0881c70/ea0f8fb0-8e69-4338-9cd1-8a72a0881c70.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/9a41f19e-0ed2-49e1-bdd3-f8e6d3304bee/9a41f19e-0ed2-49e1-bdd3-f8e6d3304bee.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/3f55bf7a-70b5-495c-8f71-131d7d24c081/3f55bf7a-70b5-495c-8f71-131d7d24c081.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/a19357db-c965-4057-ac64-562d99059be9/a19357db-c965-4057-ac64-562d99059be9.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/a4780b8b-74ec-468f-80ba-69c55cc2f0ac/a4780b8b-74ec-468f-80ba-69c55cc2f0ac.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/fedb29ca-25ec-4693-8266-86b57096d494/fedb29ca-25ec-4693-8266-86b57096d494.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/22ed1574-a1b3-4963-817d-be226113a100/22ed1574-a1b3-4963-817d-be226113a100.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/acc3075a-d979-48fa-8209-181da7f1d152/acc3075a-d979-48fa-8209-181da7f1d152.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/75519ad0-056f-41d0-8ce7-9bd816e4a90b/75519ad0-056f-41d0-8ce7-9bd816e4a90b.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/20f0df83-5b9f-45f3-a537-7c595d4138d1/20f0df83-5b9f-45f3-a537-7c595d4138d1.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/dbbf2908-f87a-4123-b5f7-547e5d909aa0/dbbf2908-f87a-4123-b5f7-547e5d909aa0.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/f997e864-b486-48c9-890a-ee3b0aafaa7c/f997e864-b486-48c9-890a-ee3b0aafaa7c.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/de2b294e-e9f9-46e5-a2f3-54bf4f2b2bce/de2b294e-e9f9-46e5-a2f3-54bf4f2b2bce.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/db7c3fca-a51a-4c7a-a7f2-4bafca3fcd51/db7c3fca-a51a-4c7a-a7f2-4bafca3fcd51.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/28b70ebb-71f7-4131-8d43-8482805a7722/28b70ebb-71f7-4131-8d43-8482805a7722.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/2cd4f524-4392-4ac2-b8a4-db7927628ab9/2cd4f524-4392-4ac2-b8a4-db7927628ab9.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/88b0e68c-fa16-48ea-983b-f11ffd9362b7/88b0e68c-fa16-48ea-983b-f11ffd9362b7.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/17e57cf5-3367-43fe-90e2-eb41d69aedef/17e57cf5-3367-43fe-90e2-eb41d69aedef.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/d6c1e554-a571-48df-8b7b-ea22d8c9f0ca/d6c1e554-a571-48df-8b7b-ea22d8c9f0ca.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/1a69cd94-82ae-4b9a-b29e-c1846fca7c27/1a69cd94-82ae-4b9a-b29e-c1846fca7c27.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/646a4121-5f37-42c8-88c1-c922e0711dfb/646a4121-5f37-42c8-88c1-c922e0711dfb.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/05da5d7f-01cd-4ce8-bca8-44070b423eac/05da5d7f-01cd-4ce8-bca8-44070b423eac.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/ff9c5a14-e55a-48a0-995b-c4e5c6280843/ff9c5a14-e55a-48a0-995b-c4e5c6280843.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/e259c998-d9d1-4e05-9f6d-88db34a24073/e259c998-d9d1-4e05-9f6d-88db34a24073.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/c2f1382f-59fd-4ae8-b8ad-95f92393db9c/c2f1382f-59fd-4ae8-b8ad-95f92393db9c.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/19542c5d-5e01-4b99-8018-3ad64399857b/19542c5d-5e01-4b99-8018-3ad64399857b.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/82cba50f-dd8b-4384-92fc-a8ca6a46ba1f/82cba50f-dd8b-4384-92fc-a8ca6a46ba1f.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/af5ac656-ad8f-4ec5-9cb5-35f1cc00eba5/af5ac656-ad8f-4ec5-9cb5-35f1cc00eba5.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/db84f963-195b-4e28-84d0-25b4c10e5f18/db84f963-195b-4e28-84d0-25b4c10e5f18.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/cf819562-1b4d-4678-b46f-9dd23da907d1/cf819562-1b4d-4678-b46f-9dd23da907d1.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/c522a3c8-46af-4c3a-bc80-da8d6855ad99/c522a3c8-46af-4c3a-bc80-da8d6855ad99.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
WARNING:rasterio._env:CPLE_NotSupported in 'RGBA' is an unexpected value for PHOTOMETRIC creation option of type string-select.
WARNING:rasterio._env:CPLE_IllegalArg in /vsimem/daee4a90-5312-4ec7-b987-1ce0ba4b2a81/daee4a90-5312-4ec7-b987-1ce0ba4b2a81.tif: PHOTOMETRIC=RGBA value not recognised, ignoring.  Set the Photometric Interpretation as MINISBLACK.
The time taken to predict was:  418.4003023790001
The time taken to generate a georrefenced image and save it was:  4.671972782000012
In [ ]:
output_filepath = os.path.join(OUTPUT_DIR, 'merged_prediction.tif')
src = rasterio.open(output_filepath)
show(src.read(), transform=src.transform)
No description has been provided for this image
Out[ ]:
<Axes: >

Reference:

https://www.kaggle.com/code/luizclaudioandrade/torchgeo-101

https://torchgeo.readthedocs.io/en/stable/tutorials/getting_started.html

https://pytorch.org/blog/geospatial-deep-learning-with-torchgeo/